#!/usr/bin/env python3
# Copyright 2023 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import copy
import glob
import os
from pathlib import Path
import sys
from typing import Any, Iterable, List, Optional, Union
from impl.common import (
    CROSVM_ROOT,
    TOOLS_ROOT,
    Command,
    Remote,
    quoted,
    Styles,
    argh,
    console,
    chdir,
    cmd,
    record_time,
    run_main,
    sudo_is_passwordless,
    verbose,
    Triple,
)
from impl.test_config import ROOT_TESTS, DO_NOT_RUN, DO_NOT_RUN_AARCH64, DO_NOT_RUN_WIN64, E2E_TESTS
from impl.test_config import DO_NOT_BUILD_RISCV64, DO_NOT_RUN_WINE64
from impl import testvm

# Command handles for external tools, built with `cmd` from impl.common.
# NOTE(review): neither handle is referenced in the visible part of this file; main() builds its
# cargo invocations with cmd("cargo ...") directly. Confirm they are still needed.
rsync = cmd("rsync")
cargo = cmd("cargo")

# Name of the directory used to package all test files.
# It names the package directory below the triple's target dir, is the stem of the `.tar.zst`
# archive, and is the directory `run.sh` is invoked from on an SSH target.
PACKAGE_NAME = "integration_tests_package"


def join_filters(items: Iterable[str], op: str):
    """
    Combines filter expressions into a single expression string.

    Every item is wrapped in its own pair of parentheses and the wrapped items are joined with
    `op` as the separator. No whitespace is added around `op`. An empty `items` yields "".
    """
    wrapped = [f"({expression})" for expression in items]
    return op.join(wrapped)


class TestFilter(object):
    """
    Utility structure to join user-provided filter expressions with additional filters

    Instances are never modified in place: `exclude`, `include` and `subset` each return a new
    TestFilter. An empty expression stands for "no filter at all": `to_args` then emits no
    arguments, so the test selection is left unrestricted.

    See https://nexte.st/book/filter-expressions.html
    """

    def __init__(self, expression: str):
        self.expression = expression

    def exclude(self, *exclude_exprs: str):
        """
        Returns a filter that additionally rejects everything matched by any of `exclude_exprs`.

        Without any expressions there is nothing to reject, so an equivalent filter is returned
        instead of building the invalid expression `not ()`.
        """
        if not exclude_exprs:
            return TestFilter(self.expression)
        return self.subset(f"not ({join_filters(exclude_exprs, '|')})")

    def include(self, *include_exprs: str):
        """
        Returns a filter that additionally accepts everything matched by any of `include_exprs`.

        An equivalent filter is returned when there is nothing to add, or when this filter is
        empty: an empty filter is already unrestricted, and or-ing it with further expressions
        would otherwise produce the invalid expression `() | (...)`.
        """
        if not include_exprs or not self.expression:
            return TestFilter(self.expression)
        include_expr = join_filters(include_exprs, "|")
        return TestFilter(f"({self.expression}) | ({include_expr})")

    def subset(self, *subset_exprs: str):
        """
        Returns a filter restricted to what also matches at least one of `subset_exprs`.

        If this filter is empty, the or-ed `subset_exprs` become the whole expression.
        """
        subset_expr = join_filters(subset_exprs, "|")
        if not self.expression:
            return TestFilter(subset_expr)
        return TestFilter(f"({self.expression}) & ({subset_expr})")

    def to_args(self):
        """
        Yields the command line arguments that apply this filter.

        Nothing is yielded for an empty expression. Otherwise `--filter-expr` is yielded,
        followed by the expression passed through `quoted`.
        """
        if not self.expression:
            return
        yield "--filter-expr"
        yield quoted(self.expression)


def configure_cargo(
    cmd: Command, triple: Triple, features: Optional[str], no_default_features: bool
):
    """
    Returns `cmd` extended with the cargo arguments and environment needed to build for `triple`.

    `--workspace` is always added. `--no-default-features` and `--features=<features>` are only
    added when requested; otherwise None is passed in their place (presumably skipped by
    `Command.with_args`). The color flag and the environment from `triple.get_cargo_env()` are
    applied on top.
    """
    default_features_flag = "--no-default-features" if no_default_features else None
    features_flag = f"--features={features}" if features else None

    with_flags = cmd.with_args("--workspace", default_features_flag, features_flag)
    return with_flags.with_color_flag().with_envs(triple.get_cargo_env())


class HostTarget(object):
    """Test target that invokes the `run.sh` script found in a package directory via `cmd`."""

    def __init__(self, package_dir: Path):
        run_script = package_dir / "run.sh"
        self.run_cmd = cmd(run_script).with_color_flag()

    def run_tests(self, extra_args: List[Any]):
        """
        Runs `run.sh` with `extra_args` appended and returns the result of `fg`.

        `check=False` is passed so a failing run is reported through the return value; the
        caller in this file treats that value as the exit code.
        """
        test_cmd = self.run_cmd.with_args(*extra_args)
        return test_cmd.fg(style=Styles.live_truncated(), check=False)


class SshTarget(object):
    """
    Test target that runs the packaged tests on a remote machine through `Remote`.

    Constructing the target already does the transfer work: the package archive is copied with
    scp and unpacked with tar, and for x86_64 and aarch64 the GPU libraries found on the local
    filesystem are copied into the same library directory on the remote.
    """

    def __init__(self, package_archive: Path, remote: Remote, vm_arch: str):
        console.print("Transfering integration tests package...")
        with record_time("Transfering"):
            remote.scp([package_archive], "")
        with record_time("Unpacking"):
            unpack_cmd = cmd("tar xaf", package_archive.name)
            remote.ssh(unpack_cmd).fg(style=Styles.live_truncated())

        # Transfer additional libraries for GPU tests.
        # Other architectures have no known library directory and skip this step.
        gpu_lib_dirs = {
            "x86_64": "/usr/lib/x86_64-linux-gnu",
            "aarch64": "/usr/lib/aarch64-linux-gnu",
        }
        lib_dir = gpu_lib_dirs.get(vm_arch, "")

        if lib_dir:
            gpu_libs: List[Path] = []
            for lib_glob in ("libvirglrenderer.*", "libgbm.so*", "libgfxstream_backend.so*"):
                # The glob is evaluated locally; only regular files (or links to them) are sent.
                for match in glob.glob(f"{lib_dir}/{lib_glob}"):
                    candidate = Path(match)
                    if candidate.is_file():
                        gpu_libs.append(candidate)

            if gpu_libs:
                console.print(f"Transfering {vm_arch} libraries for GPU tests...")
                with record_time("Transfering libs"):
                    make_lib_dir = cmd(f"sudo mkdir -p {lib_dir}")
                    remote.ssh(make_lib_dir).fg(style=Styles.live_truncated())
                    # scp to the default remote directory first, then move the files into
                    # place with sudo.
                    remote.scp(gpu_libs, "")
                    lib_names = [lib.name for lib in gpu_libs]
                    install_libs = cmd("sudo", "mv", *lib_names, lib_dir)
                    remote.ssh(install_libs).fg(style=Styles.live_truncated())

        self.remote = remote
        self.remote_run_cmd = cmd(f"{PACKAGE_NAME}/run.sh").with_color_flag()

    def run_tests(self, extra_args: List[Any]):
        """
        Runs the unpacked `run.sh` over ssh with `extra_args` appended.

        `check=False` is passed so a failing run is reported through the return value.
        """
        remote_cmd = self.remote_run_cmd.with_args(*extra_args)
        return self.remote.ssh(remote_cmd).fg(
            style=Styles.live_truncated(),
            check=False,
        )


def check_host_prerequisites(run_root_tests: bool):
    """
    Check various prerequisites for executing test binaries.

    Nothing is checked on Windows. Elsewhere, `sudo true` is run first when root tests are
    enabled, and the process exits with status 1 if /dev/kvm or /dev/vhost-vsock is not both
    readable and writable.
    """
    if os.name == "nt":
        return

    if run_root_tests:
        console.print("Running tests that require root privileges. Refreshing sudo now.")
        cmd("sudo true").fg()

    required_devices = ("/dev/kvm", "/dev/vhost-vsock")
    read_write = os.R_OK | os.W_OK
    for device in required_devices:
        if os.access(device, read_write):
            continue
        console.print(f"{device} access is required", style="red")
        sys.exit(1)


def check_build_prerequisites(triple: Triple):
    """
    Exits with status 1 unless `triple` is listed by `rustup target list --installed`.

    Before exiting, a tip is printed that shows how to re-run the same invocation inside the dev
    container.
    """
    # Imported locally so this block does not depend on changes to the file-level imports.
    import shlex

    installed_toolchains = cmd("rustup target list --installed").lines()
    if str(triple) in installed_toolchains:
        return

    # Shell-quote the forwarded arguments. Joining them with plain spaces would drop the quoting
    # of arguments such as -E 'rdeps(my-crate)', so the suggested command could not be pasted
    # back into a shell.
    forwarded_args = shlex.join(sys.argv[1:])

    console.print(f"Your host is not configured to build for [green]{triple}[/green]")
    console.print("[green]Tip:[/green] Run tests in the dev container with:")
    console.print()
    console.print(f"  [blue]$ tools/dev_container tools/run_tests {forwarded_args}[/blue]")
    sys.exit(1)


def get_vm_arch(triple: Triple):
    """
    Maps a target triple to the architecture name of the matching test VM.

    Raises an Exception for any triple other than the three linux-gnu triples listed below.
    """
    vm_arch_by_triple = {
        "x86_64-unknown-linux-gnu": "x86_64",
        "aarch64-unknown-linux-gnu": "aarch64",
        "riscv64gc-unknown-linux-gnu": "riscv64",
    }
    vm_arch = vm_arch_by_triple.get(str(triple))
    if vm_arch is None:
        raise Exception(f"{triple} is not supported for running tests in a VM.")
    return vm_arch


# Entry point of the test runner. The flow is:
#   1. resolve the target triple, features and device-under-test (dut) from the arguments,
#   2. print a summary and tips, then prepare the dut (host checks or VM start),
#   3. build all tests with `cargo nextest run --no-run`,
#   4. run the unit tests on the host unless disabled,
#   5. if a dut is selected: build crosvm, package the integration tests and run them on the dut.
# Any failing step ends the process through sys.exit with that step's return code.
# The docstring of main() is kept verbatim because it is presumably surfaced as CLI help by argh.
@argh.arg("--filter-expr", "-E", type=str, action="append", help="Nextest filter expression.")
@argh.arg(
    "--platform", "-p", help="Which platform to test. (x86_64, aarch64, armhw, mingw64, riscv64)"
)
@argh.arg("--dut", help="Which device to test on. (vm or host)")
@argh.arg("--no-default-features", help="Don't enable default features")
@argh.arg("--no-run", "--build-only", help="Build only, do not run any tests.")
@argh.arg("--no-unit-tests", help="Do not run unit tests.")
@argh.arg("--no-integration-tests", help="Do not run integration tests.")
@argh.arg("--no-strip", help="Do not strip test binaries of debug info.")
@argh.arg("--run-root-tests", help="Enables integration tests that require root privileges.")
@argh.arg(
    "--features",
    help=f"List of comma separated features to be passed to cargo. Defaults to `all-$platform`",
)
@argh.arg("--no-parallel", help="Do not parallelize integration tests. Slower but more stable.")
@argh.arg("--repetitions", help="Repeat all tests, useful for checking test stability.")
@argh.arg("--retries", help="Number of test retries on failure..")
def main(
    filter_expr: List[str] = [],
    platform: Optional[str] = None,
    dut: Optional[str] = None,
    no_default_features: bool = False,
    no_run: bool = False,
    no_unit_tests: bool = False,
    no_integration_tests: bool = False,
    no_strip: bool = False,
    run_root_tests: bool = False,
    features: Optional[str] = None,
    no_parallel: bool = False,
    repetitions: int = 1,
    retries: Optional[int] = None,
):
    """
    Runs all crosvm tests

    For details on how crosvm tests are organized, see https://crosvm.dev/book/testing/index.html

    # Basic Usage

    To run all unit tests for the hosts native architecture:

    $ ./tools/run_tests

    To run all unit tests for another supported architecture using an emulator (e.g. wine64,
    qemu user space emulation).

    $ ./tools/run_tests -p aarch64
    $ ./tools/run_tests -p armhw
    $ ./tools/run_tests -p mingw64

    # Integration Tests

    Integration tests can be run on a built-in virtual machine:

    $ ./tools/run_tests --dut=vm
    $ ./tools/run_tests --dut=vm -p aarch64

    The virtual machine is automatically started for the test process and can be managed via the
    `./tools/x86vm` or `./tools/aarch64vm` tools.

    Integration tests can be run on the host machine as well, but cannot be guaranteed to work on
    all configurations.

    $ ./tools/run_tests --dut=host

    # Test Filtering

    This script supports nextest filter expressions: https://nexte.st/book/filter-expressions.html

    For example to run all tests in `my-crate` and all crates that depend on it:

    $ ./tools/run_tests [--dut=] -E 'rdeps(my-crate)'
    """
    chdir(CROSVM_ROOT)

    # Fail early with a setup hint when cargo-nextest is not on the PATH.
    if os.name == "posix" and not cmd("which cargo-nextest").success():
        raise Exception("Cannot find cargo-nextest. Please re-run `./tools/setup`")
    elif os.name == "nt" and not cmd("where.exe cargo-nextest.exe").success():
        raise Exception("Cannot find cargo-nextest. Please re-run `./tools/install-deps.ps1`")

    triple = Triple.from_shorthand(platform) if platform else Triple.host_default()

    # Multiple -E expressions are or-ed together; no expression gives an empty (no-op) filter.
    test_filter = TestFilter(join_filters(filter_expr, "|"))

    # Fall back to the triple's own feature flag unless features were chosen explicitly.
    if not features and not no_default_features:
        features = triple.feature_flag

    # --no-run still builds everything below, it only disables both kinds of test execution.
    if no_run:
        no_integration_tests = True
        no_unit_tests = True

    # Disable the DUT if integration tests are not run.
    if no_integration_tests:
        dut = None

    # Automatically enable tests that require root if sudo is passwordless
    if not run_root_tests:
        if dut == "host":
            run_root_tests = sudo_is_passwordless()
        elif dut == "vm":
            # The test VMs have passwordless sudo configured.
            run_root_tests = True

    # Print summary of tests and where they will be executed.
    # This chain also validates --dut: anything but host, vm (posix only) or None is rejected.
    if dut == "host":
        dut_str = "Run on host"
    elif dut == "vm" and os.name == "posix":
        dut_str = f"Run on built-in {get_vm_arch(triple)} vm"
    elif dut is None:
        dut_str = "[yellow]Skip[/yellow]"
    else:
        raise Exception(
            f"--dut={dut} is not supported. Options are --dut=host or --dut=vm (linux only)"
        )

    skip_str = "[yellow]skip[/yellow]"
    unit_test_str = "Run on host" if not no_unit_tests else skip_str
    integration_test_str = dut_str if dut else skip_str
    profile = os.environ.get("NEXTEST_PROFILE", "default")
    console.print(f"Running tests for [green]{triple}[/green]")
    console.print(f"Profile: [green]{profile}[/green]")
    console.print(f"With features: [green]{features}[/green]")
    console.print(f"no-default-features: [green]{no_default_features}[/green]")
    console.print()
    console.print(f"  Unit tests:        [bold]{unit_test_str}[/bold]")
    console.print(f"  Integration tests: [bold]{integration_test_str}[/bold]")
    console.print()

    check_build_prerequisites(triple)

    # Print tips in certain configurations.
    if dut and not run_root_tests:
        console.print(
            "[green]Tip:[/green] Skipping tests that require root privileges. "
            + "Use [bold]--run-root-tests[/bold] to enable them."
        )
    if not dut:
        console.print(
            "[green]Tip:[/green] To run integration tests on a built-in VM: "
            + "Use [bold]--dut=vm[/bold] (preferred)"
        )
        console.print(
            "[green]Tip:[/green] To run integration tests on the host: Use "
            + "[bold]--dut=host[/bold] (fast, but unreliable)"
        )
    if dut == "vm":
        # Only x86_64 and aarch64 have a VM management tool; riscv64 is rejected here.
        vm_arch = get_vm_arch(triple)
        if vm_arch == "x86_64":
            cli_tool = "tools/x86vm"
        elif vm_arch == "aarch64":
            cli_tool = "tools/aarch64vm"
        else:
            raise Exception(f"Unknown vm arch '{vm_arch}'")
        console.print(
            f"[green]Tip:[/green] The test VM will remain alive between tests. You can manage this VM with [bold]{cli_tool}[/bold]"
        )

    # Prepare the dut for test execution
    if dut == "host":
        check_host_prerequisites(run_root_tests)
    if dut == "vm":
        # Start VM ahead of time but don't wait for it to boot.
        testvm.up(get_vm_arch(triple))

    # Packages listed in DO_NOT_BUILD_RISCV64 are left out of every riscv64 build.
    riscv_exclude_args = []
    if triple == Triple.from_shorthand("riscv64"):
        riscv_exclude_args = ["--exclude=" + s for s in DO_NOT_BUILD_RISCV64]

    # Shared between the nextest build, the unit test run and the integration test run.
    nextest_args = [
        f"--profile={profile}" if profile else None,
        "--verbose" if verbose() else None,
    ] + riscv_exclude_args

    nextest_run = configure_cargo(
        cmd("cargo nextest run"), triple, features, no_default_features
    ).with_args(*nextest_args)

    console.print()
    console.rule("Building tests")
    with record_time("Build tests"):
        returncode = nextest_run.with_args("--no-run").fg(
            style=Styles.live_truncated(), check=False
        )
        if returncode != 0:
            sys.exit(returncode)

    if not no_unit_tests:
        # Unit tests are everything selected by the user minus the e2e tests, plus benchmarks.
        unit_test_filter = copy.deepcopy(test_filter).exclude(*E2E_TESTS).include("kind(bench)")
        if triple == Triple.from_shorthand("mingw64") and os.name == "posix":
            unit_test_filter = unit_test_filter.exclude(*DO_NOT_RUN_WINE64)
        console.print()
        console.rule("Running unit tests")
        with record_time("Unit Tests"):
            for i in range(repetitions):
                if repetitions > 1:
                    console.rule(f"Round {i}", style="grey")

                returncode = nextest_run.with_args(
                    "--lib --bins",
                    f"--retries={retries}" if retries is not None else None,
                    *unit_test_filter.to_args(),
                ).fg(style=Styles.live_truncated(), check=False)
                # The first failing round aborts the remaining repetitions.
                if returncode != 0:
                    sys.exit(returncode)

    if dut:
        console.print()
        console.rule("Building crosvm binary for e2e tests")
        with record_time("Build crosvm binary"):
            cargo_build = configure_cargo(
                cmd("cargo build -p crosvm"), triple, features, no_default_features
            ).with_args(*riscv_exclude_args)
            returncode = cargo_build.fg(style=Styles.live_truncated(), check=False)
            if returncode != 0:
                sys.exit(returncode)

        # Getting crosvm binary's path
        cargo_metadata = (
            cmd("cargo metadata --format-version=1 --no-deps")
            .with_args(
                "--no-default-features" if no_default_features else None,
                f"--features={features}" if features else None,
            )
            .with_color_flag()
            .with_envs(triple.get_cargo_env())
        )
        json_result = cargo_metadata.json()
        if not json_result:
            raise Exception(f"Failed to get cargo metadata")
        crosvm_bin_path = f"{json_result['target_directory']}/{triple}/debug/crosvm"

        package_dir = triple.target_dir / PACKAGE_NAME
        package_archive = package_dir.with_suffix(".tar.zst")
        nextest_package = configure_cargo(
            cmd(TOOLS_ROOT / "nextest_package"), triple, features, no_default_features
        )

        # Collect everything that must not run in this configuration. From here on test_filter
        # is used for both packaging and running the integration tests.
        test_exclusions = [*DO_NOT_RUN]
        if not run_root_tests:
            test_exclusions += ROOT_TESTS
        if triple == Triple.from_shorthand("mingw64"):
            test_exclusions += DO_NOT_RUN_WIN64
            if os.name == "posix":
                test_exclusions += DO_NOT_RUN_WINE64
        if triple == Triple.from_shorthand("aarch64"):
            test_exclusions += DO_NOT_RUN_AARCH64
        test_filter = test_filter.exclude(*test_exclusions)

        console.print()
        console.rule("Packaging integration tests")
        with record_time("Packing"):
            # The archive is only requested when the package has to be shipped to a non-host dut.
            nextest_package(
                "--test *",
                f"--bin {crosvm_bin_path}",
                f"-d {package_dir}",
                f"-o {package_archive}" if dut != "host" else None,
                "--no-strip" if no_strip else None,
                *test_filter.to_args(),
                "--verbose" if verbose() else None,
            ).fg(style=Styles.live_truncated())

        target: Union[HostTarget, SshTarget]
        if dut == "host":
            target = HostTarget(package_dir)
        elif dut == "vm":
            vm_arch = get_vm_arch(triple)
            # The VM was started earlier without waiting; block now until it is usable.
            testvm.up(vm_arch, wait=True)
            remote = Remote("localhost", testvm.ssh_opts(vm_arch))
            target = SshTarget(package_archive, remote, vm_arch)
        else:
            # Not reachable after the --dut validation above; kept so `target` can never be
            # used unbound if that validation changes.
            raise Exception(
                f"--dut={dut} is not supported. Options are --dut=host or --dut=vm (linux only)"
            )

        console.print()
        console.rule("Running integration tests")
        with record_time("Integration tests"):
            for i in range(repetitions):
                if repetitions > 1:
                    console.rule(f"Round {i}", style="grey")
                returncode = target.run_tests(
                    [
                        *test_filter.to_args(),
                        *nextest_args,
                        f"--retries={retries}" if retries is not None else None,
                        "--test-threads=1" if no_parallel else None,
                    ]
                )
                if returncode != 0:
                    if not no_parallel:
                        console.print(
                            "[green]Tip:[/green] Tests may fail when run in parallel on some platforms. "
                            + "Try re-running with `--no-parallel`"
                        )
                    if dut == "host":
                        console.print(
                            f"[yellow]Tip:[/yellow] Running tests on the host may not be reliable. "
                            "Prefer [bold]--dut=vm[/bold]."
                        )
                    sys.exit(returncode)


if __name__ == "__main__":
    run_main(main)
